Azure WAF/Alert - Process Azure FrontDoor Alerts/ProcessAFDAlerts.cs (294 lines of code) (raw):
// //////////////////////////////////////////////////////////////////////////////
//
// Copyright (C) Microsoft Corporation. All rights reserved.
//
// //////////////////////////////////////////////////////////////////////////////
namespace processAfdAlerts
{
using System;
using System.Collections.Generic;
using System.Globalization;
using System.IO;
using System.Linq;
using System.Threading.Tasks;
using Azure.Identity;
using Azure.Monitor.Query;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Mvc;
using Microsoft.Azure.Management.FrontDoor;
using Microsoft.Azure.Management.FrontDoor.Models;
using Microsoft.Azure.Management.ResourceManager.Fluent;
using Microsoft.Azure.WebJobs;
using Microsoft.Azure.WebJobs.Extensions.Http;
using Microsoft.Extensions.Logging;
using Newtonsoft.Json;
/// <summary>
/// Azure Function that processes Azure Monitor alerts raised for an Azure Front Door.
/// A "Fired" alert adds or refreshes per-country rate-limit custom rules on the linked WAF policy;
/// any other monitor condition removes those rules again.
/// </summary>
public static class ProcessAfdAlerts
{
    // Custom rule names are "<country code><postfix>", so each country gets its own pair of rules.
    private const string MitigateDDOSRateLimitCountryRuleNamePostfix = "MitigateDDOSRateLimitCountryRule";
    private const string MitigateDDOSRateLimitTopRequestIPsRuleNamePostfix = "MitigateDDOSRateLimitTopRequestIPsRule";

    // Lower bounds for the dynamically computed rate-limit thresholds, and the rate-limit window.
    private const int MinIpRateLimitThreshold = 100;
    private const int MinCountryRateLimitThreshold = 1000;
    private const int CountryThresholdMultiplier = 10;
    private const int RuleRateLimitDurationInMinutes = 5;

    // Priorities: known countries use fixed slots (country rule N, IP rule IpRulePriorityOffset + N);
    // other countries get the lowest priority not already used by a rule, within the ranges below.
    private const int IpRulePriorityOffset = 1000;
    private const int OtherCountryRulePriorityMin = 10;
    private const int OtherCountryRulePriorityMaxExclusive = 1000;
    private const int OtherIpRulePriorityMin = 1100;
    private const int OtherIpRulePriorityMaxExclusive = 2000;

    // How far back the access logs are searched when an alert fires.
    private static readonly TimeSpan LogQueryWindow = TimeSpan.FromMinutes(20);

    // Service principal used both to read Log Analytics and to update the WAF policy.
    // NOTE(review): prefer loading these from app settings / Key Vault instead of source code.
    private static readonly string ClientId = "";
    private static readonly string ClientSecret = "";
    private static readonly string TenantId = "";

    // Log Analytics workspace that receives the Front Door access logs.
    private static readonly string WorkspaceId = "";

    // Linked WAF Policy Info
    private static readonly string WafPolicyName = "";
    private static readonly string WafPolicyResourceGroupName = "";
    private static readonly string WafPolicySubscriptionId = "";

    // Frontdoor resourceId, used to filter the access logs.
    private static readonly string FrontdoorResourceId = "";

    // Country name as reported by the alert dimension -> two-letter code used in the GeoMatch condition.
    // Names not listed here are used unchanged.
    private static readonly Dictionary<string, string> CountryCodesByName =
        new Dictionary<string, string>(StringComparer.OrdinalIgnoreCase)
        {
            ["united states"] = "US",
            ["unitedstates"] = "US",
            ["canada"] = "CA",
            ["brazil"] = "BR",
            ["ireland"] = "IE",
            ["australia"] = "AU",
            ["singapore"] = "SG",
            ["japan"] = "JP",
            ["france"] = "FR",
            ["india"] = "IN",
        };

    // Fixed country-rule priority per known country code; the matching IP rule uses IpRulePriorityOffset + value.
    private static readonly Dictionary<string, int> KnownCountryRulePriorities =
        new Dictionary<string, int>(StringComparer.Ordinal)
        {
            ["US"] = 1,
            ["CA"] = 2,
            ["BR"] = 3,
            ["IE"] = 4,
            ["AU"] = 5,
            ["SG"] = 6,
            ["JP"] = 7,
            ["FR"] = 8,
            ["IN"] = 9,
        };

    /// <summary>
    /// HTTP entry point invoked with an Azure Monitor alert payload.
    /// </summary>
    /// <param name="req">Incoming request whose body is deserialized into <c>AlertBody</c>.</param>
    /// <param name="log">Function logger.</param>
    /// <returns>
    /// 400 when the body cannot be parsed or lacks the country dimension / threshold;
    /// 200 once the linked WAF policy has been updated.
    /// </returns>
    [FunctionName("ProcessAfdAlerts")]
    public static async Task<IActionResult> Run(
        [HttpTrigger(AuthorizationLevel.Function, "get", "post", Route = null)]
        HttpRequest req,
        ILogger log)
    {
        // 1. Parse the alert message from the incoming request's body
        string requestBody;
        using (var reader = new StreamReader(req.Body))
        {
            requestBody = await reader.ReadToEndAsync();
        }

        AlertBody alertBody;
        try
        {
            alertBody = JsonConvert.DeserializeObject<AlertBody>(requestBody);
        }
        catch (Exception e)
        {
            log.LogWarning(e, "Failed to deserialize the alert request body");
            return new BadRequestObjectResult("Failed to deserialize the request body");
        }

        if (alertBody == null)
        {
            return new BadRequestObjectResult("AlertBody is null");
        }

        // 2. Extract info from the alert. Only the first condition and its first dimension are read.
        // NOTE(review): assumes the alert rule is split by a single country dimension — confirm.
        var condition = alertBody.data?.alertContext?.condition?.allOf?.FirstOrDefault();
        var country = condition?.dimensions?.FirstOrDefault()?.value;
        if (string.IsNullOrWhiteSpace(country))
        {
            return new BadRequestObjectResult("Alert does not contain a country dimension");
        }

        var alertInfo = new AlertInfo()
        {
            Country = country
        };

        var monitorCondition = alertBody.data.essentials?.monitorCondition;
        log.LogInformation("Processing alert {MonitorCondition} for country {Country}", monitorCondition, country);

        // 3. Check if the alert was fired/activated or resolved/deactivated
        if (monitorCondition == "Fired")
        {
            try
            {
                // Baseline is one above the alert threshold; parse with the invariant culture so the
                // host's regional settings cannot change how a decimal threshold is read.
                alertInfo.baselineThreshold =
                    Convert.ToInt32(Convert.ToDouble(condition.threshold, CultureInfo.InvariantCulture)) + 1;
            }
            catch (Exception e) when (e is FormatException || e is InvalidCastException || e is OverflowException)
            {
                log.LogWarning(e, "Alert threshold could not be parsed");
                return new BadRequestObjectResult("Alert threshold is missing or invalid");
            }

            await HandleAlertFired(log, alertInfo);
        }
        else
        {
            await HandleAlertResolved(log, alertInfo);
        }

        return new OkObjectResult("All done ... ");
    }

    /// <summary>
    /// Removes this country's rate-limit rules from the linked WAF policy.
    /// </summary>
    private static async Task HandleAlertResolved(ILogger log, AlertInfo alertInfo)
    {
        await UpdateLinkedWafPolicy(alertInfo, log, logs: null, deleteRules: true).ConfigureAwait(false);
    }

    /// <summary>
    /// Finds the client IPs above the baseline and adds/refreshes the rate-limit rules for the country.
    /// </summary>
    private static async Task HandleAlertFired(ILogger log, AlertInfo alertInfo)
    {
        // 1. Query the access logs over LogQueryWindow for the top client IPs of the alerting country.
        var logs = await GetLogs(log, alertInfo);
        if (logs == null)
        {
            // nothing to do
            return;
        }

        // 2. Update the linked WAF policy with new rules to mitigate the attack.
        await UpdateLinkedWafPolicy(alertInfo, log, logs).ConfigureAwait(false);
    }

    /// <summary>
    /// Loads the linked WAF policy, adds or removes the country's rules, and deploys the policy.
    /// </summary>
    /// <param name="alertInfo">Country and baseline threshold taken from the alert.</param>
    /// <param name="log">Function logger.</param>
    /// <param name="logs">Top client IPs from Log Analytics; only used when adding rules.</param>
    /// <param name="deleteRules">True to remove the country's rules instead of adding/updating them.</param>
    /// <exception cref="InvalidOperationException">The WAF policy could not be found.</exception>
    private static async Task UpdateLinkedWafPolicy(AlertInfo alertInfo, ILogger log, IEnumerable<Row> logs = null, bool deleteRules = false)
    {
        // 1. Create a FrontDoorManagementClient; disposed when this method returns.
        using var frontdoorClient = new FrontDoorManagementClient(
            SdkContext.AzureCredentialsFactory.FromServicePrincipal(ClientId, ClientSecret, TenantId, AzureEnvironment.AzureGlobalCloud));
        frontdoorClient.SubscriptionId = WafPolicySubscriptionId;

        // 2. Use it to get the WAF Policy
        WebApplicationFirewallPolicy wafPolicy;
        try
        {
            wafPolicy = await frontdoorClient.Policies.GetAsync(WafPolicyResourceGroupName, WafPolicyName);
        }
        catch (Exception ex)
        {
            log.LogError(ex, "Failed to get WAF policy {WafPolicyName} in resource group {ResourceGroup}",
                WafPolicyName, WafPolicyResourceGroupName);
            throw;
        }

        if (wafPolicy == null)
        {
            log.LogError("WAF policy {WafPolicyName} does not exist", WafPolicyName);
            throw new InvalidOperationException("Waf policy does not exist");
        }

        // 3. Add or remove the rate limit rules from the WAF policy
        if (deleteRules)
        {
            DeleteRulesFromWafPolicy(wafPolicy, alertInfo);
        }
        else
        {
            UpdateWafWithRulesToStopAttack(wafPolicy, alertInfo, logs);
        }

        // 4. Update/Deploy the WAF policy
        try
        {
            await frontdoorClient.Policies.CreateOrUpdateAsync(WafPolicyResourceGroupName, WafPolicyName, wafPolicy);
        }
        catch (Exception e)
        {
            log.LogError(e, "Failed to update waf policy");
            throw;
        }

        log.LogInformation("Updated WAF policy {WafPolicyName}", WafPolicyName);
    }

    /// <summary>
    /// Removes both the country rate-limit rule and the top-IPs rule for the alert's country, if present.
    /// </summary>
    private static void DeleteRulesFromWafPolicy(WebApplicationFirewallPolicy wafPolicy, AlertInfo alertInfo)
    {
        DeleteRateLimitCountryRule(wafPolicy, alertInfo);
        DeleteRateLimitIPRule(wafPolicy, alertInfo);
    }

    private static void DeleteRateLimitCountryRule(WebApplicationFirewallPolicy wafPolicy, AlertInfo alertInfo)
    {
        // A non-null result implies CustomRules.Rules exists, so the removal below is safe.
        var countryRateLimitRule = GetCountryRateLimitRule(wafPolicy, alertInfo.Country);
        if (countryRateLimitRule != null)
        {
            wafPolicy.CustomRules.Rules.Remove(countryRateLimitRule);
        }
    }

    private static void DeleteRateLimitIPRule(WebApplicationFirewallPolicy wafPolicy, AlertInfo alertInfo)
    {
        // A non-null result implies CustomRules.Rules exists, so the removal below is safe.
        var rateLimitIpsRule = GetRateLimitIPRuleByCountry(wafPolicy, alertInfo.Country);
        if (rateLimitIpsRule != null)
        {
            wafPolicy.CustomRules.Rules.Remove(rateLimitIpsRule);
        }
    }

    /// <summary>
    /// Adds or refreshes the country-wide rule and the per-IP rule for the alert's country.
    /// </summary>
    private static void UpdateWafWithRulesToStopAttack(
        WebApplicationFirewallPolicy wafPolicy,
        AlertInfo alertInfo,
        IEnumerable<Row> logs)
    {
        // 1. Create or update a rule to rate limit country traffic
        CreateOrUpdateRateLimitCountryRule(wafPolicy, alertInfo);
        // 2. Create or update a rule to rate limit the IPs sending requests over dynamically detected baseline
        CreateOrUpdateRateLimitIpsRule(wafPolicy, alertInfo, logs);
    }

    /// <summary>
    /// Creates (if needed) and refreshes the rule that rate-limits the given client IPs.
    /// Does nothing when <paramref name="logs"/> yields no client IP.
    /// </summary>
    private static void CreateOrUpdateRateLimitIpsRule(
        WebApplicationFirewallPolicy wafPolicy,
        AlertInfo alertInfo,
        IEnumerable<Row> logs)
    {
        // Materialize once and drop rows without a client IP so the match list only holds real addresses.
        var ipsToRateLimit = logs?
            .Select(row => row.clientIp_s)
            .Where(ip => !string.IsNullOrWhiteSpace(ip))
            .ToList();
        if (ipsToRateLimit == null || ipsToRateLimit.Count == 0)
        {
            return;
        }

        // 1. Reuse the existing rule or create a new one.
        var rateLimitIpsRule = GetRateLimitIPRuleByCountry(wafPolicy, alertInfo.Country);
        if (rateLimitIpsRule == null)
        {
            rateLimitIpsRule = new CustomRule(
                GetIpRateLimitPriorityByCountry(alertInfo.Country, wafPolicy),
                RuleType.RateLimitRule,
                new List<MatchCondition>
                {
                    new MatchCondition("RemoteAddr", "IPMatch", new List<string>())
                },
                "Block")
            {
                Name = GetIpRuleName(alertInfo.Country),
                RateLimitDurationInMinutes = RuleRateLimitDurationInMinutes
            };
            GetOrCreateRuleList(wafPolicy).Add(rateLimitIpsRule);
        }

        // 2. Update the list of IPs to be blocked based on the most recent data and enable the rule.
        rateLimitIpsRule.MatchConditions[0].MatchValue = ipsToRateLimit;
        rateLimitIpsRule.RateLimitThreshold = alertInfo.baselineThreshold < MinIpRateLimitThreshold
            ? MinIpRateLimitThreshold
            : alertInfo.baselineThreshold;
        rateLimitIpsRule.EnabledState = "Enabled";
    }

    /// <summary>
    /// Creates (if needed) and refreshes the country-wide rate-limit rule; the threshold is
    /// CountryThresholdMultiplier times the baseline, with a floor of MinCountryRateLimitThreshold.
    /// </summary>
    private static void CreateOrUpdateRateLimitCountryRule(WebApplicationFirewallPolicy wafPolicy, AlertInfo alertInfo)
    {
        var countryRateLimitRule = GetCountryRateLimitRule(wafPolicy, alertInfo.Country);
        if (countryRateLimitRule == null)
        {
            countryRateLimitRule = CreateCountryRateLimitRule(alertInfo, wafPolicy);
            GetOrCreateRuleList(wafPolicy).Add(countryRateLimitRule);
        }

        countryRateLimitRule.EnabledState = "Enabled";
        countryRateLimitRule.RateLimitThreshold =
            CountryThresholdMultiplier * alertInfo.baselineThreshold < MinCountryRateLimitThreshold
                ? MinCountryRateLimitThreshold
                : CountryThresholdMultiplier * alertInfo.baselineThreshold;
    }

    /// <summary>
    /// Returns the policy's custom rule list, creating an empty one if the policy has none yet.
    /// </summary>
    /// <exception cref="InvalidOperationException">The policy has no custom rules section at all.</exception>
    private static IList<CustomRule> GetOrCreateRuleList(WebApplicationFirewallPolicy wafPolicy)
    {
        var customRules = wafPolicy.CustomRules;
        if (customRules == null)
        {
            throw new InvalidOperationException("Waf policy has no custom rules section to add rules to");
        }

        if (customRules.Rules == null)
        {
            customRules.Rules = new List<CustomRule>();
        }

        return customRules.Rules;
    }

    private static string GetCountryRuleName(string country) =>
        $"{GetCountryCode(country)}{MitigateDDOSRateLimitCountryRuleNamePostfix}";

    private static string GetIpRuleName(string country) =>
        $"{GetCountryCode(country)}{MitigateDDOSRateLimitTopRequestIPsRuleNamePostfix}";

    /// <summary>Finds a custom rule by exact name; null when absent or the policy has no rules.</summary>
    private static CustomRule FindRuleByName(WebApplicationFirewallPolicy wafPolicy, string ruleName)
    {
        return wafPolicy.CustomRules?.Rules?.FirstOrDefault(rule => rule.Name == ruleName);
    }

    private static CustomRule GetRateLimitIPRuleByCountry(WebApplicationFirewallPolicy wafPolicy, string alertInfoCountry)
    {
        return FindRuleByName(wafPolicy, GetIpRuleName(alertInfoCountry));
    }

    private static CustomRule GetCountryRateLimitRule(
        WebApplicationFirewallPolicy webApplicationFirewallPolicy,
        string alertCountry)
    {
        return FindRuleByName(webApplicationFirewallPolicy, GetCountryRuleName(alertCountry));
    }

    /// <summary>
    /// Builds a new rate-limit rule that matches all traffic geo-located to the alert's country.
    /// </summary>
    private static CustomRule CreateCountryRateLimitRule(AlertInfo alertInfo, WebApplicationFirewallPolicy wafPolicy)
    {
        var alertCountryCode = GetCountryCode(alertInfo.Country);
        return new CustomRule(
            GetCountryRateLimitPriority(alertCountryCode, wafPolicy),
            RuleType.RateLimitRule,
            new List<MatchCondition>
            {
                new MatchCondition("RemoteAddr", "GeoMatch", new List<string> { alertCountryCode })
            },
            "Block")
        {
            Name = $"{alertCountryCode}{MitigateDDOSRateLimitCountryRuleNamePostfix}",
            RateLimitDurationInMinutes = RuleRateLimitDurationInMinutes
        };
    }

    /// <summary>
    /// Maps a country name from the alert to its two-letter code (case-insensitive);
    /// unknown names are returned unchanged.
    /// </summary>
    private static string GetCountryCode(string country)
    {
        return CountryCodesByName.TryGetValue(country, out var code) ? code : country;
    }

    /// <summary>
    /// Priority for the country-wide rule: a fixed slot for known countries, otherwise the lowest
    /// priority in [10, 1000) not yet used by any rule in the policy.
    /// </summary>
    private static int GetCountryRateLimitPriority(string countryCode, WebApplicationFirewallPolicy wafPolicy)
    {
        return KnownCountryRulePriorities.TryGetValue(countryCode, out var priority)
            ? priority
            : FindUnusedPriority(wafPolicy, OtherCountryRulePriorityMin, OtherCountryRulePriorityMaxExclusive);
    }

    /// <summary>
    /// Priority for the top-IPs rule: IpRulePriorityOffset + the country slot for known countries,
    /// otherwise the lowest priority in [1100, 2000) not yet used by any rule in the policy.
    /// </summary>
    private static int GetIpRateLimitPriorityByCountry(string country, WebApplicationFirewallPolicy wafPolicy)
    {
        return KnownCountryRulePriorities.TryGetValue(GetCountryCode(country), out var priority)
            ? IpRulePriorityOffset + priority
            : FindUnusedPriority(wafPolicy, OtherIpRulePriorityMin, OtherIpRulePriorityMaxExclusive);
    }

    /// <summary>
    /// Returns the lowest priority in [<paramref name="min"/>, <paramref name="maxExclusive"/>) that no
    /// existing custom rule uses, so newly added rules do not collide with existing ones.
    /// </summary>
    /// <exception cref="InvalidOperationException">Every priority in the range is taken.</exception>
    private static int FindUnusedPriority(WebApplicationFirewallPolicy wafPolicy, int min, int maxExclusive)
    {
        var existingRules = wafPolicy.CustomRules?.Rules ?? new List<CustomRule>();
        for (var candidate = min; candidate < maxExclusive; candidate++)
        {
            if (!existingRules.Any(rule => rule.Priority == candidate))
            {
                return candidate;
            }
        }

        throw new InvalidOperationException($"No free custom rule priority in range [{min}, {maxExclusive})");
    }

    /// <summary>
    /// Escapes a value for embedding inside a double-quoted KQL string literal.
    /// </summary>
    private static string EscapeKqlString(string value)
    {
        return (value ?? string.Empty).Replace("\\", "\\\\").Replace("\"", "\\\"");
    }

    /// <summary>
    /// Queries Log Analytics for client IPs of the alert's country whose request count over
    /// LogQueryWindow exceeds the baseline threshold, ordered by request count descending.
    /// </summary>
    private static async Task<IEnumerable<Row>> GetLogs(ILogger log, AlertInfo alertInfo)
    {
        // 1. Prepare the credentials; the AAD app needs permission to read the workspace logs.
        var credential = new ClientSecretCredential(TenantId, ClientId, ClientSecret);
        var logsClient = new LogsQueryClient(credential);

        // The country comes from the alert payload, so string values are escaped before being embedded.
        // Resource IDs are compared case-insensitively (=~) since their casing can differ between sources.
        var topIPs =
            "AzureDiagnostics " +
            "| where Category == \"FrontDoorAccessLog\" " +
            $"| where _ResourceId =~ \"{EscapeKqlString(FrontdoorResourceId)}\" " +
            $"| where clientCountry_s == \"{EscapeKqlString(alertInfo.Country)}\" " +
            "| summarize requestCount = count() by clientIp_s " +
            $"| where requestCount > {alertInfo.baselineThreshold} " +
            "| order by requestCount desc";

        try
        {
            var response = await logsClient.QueryWorkspaceAsync<Row>(
                WorkspaceId,
                topIPs,
                new QueryTimeRange(LogQueryWindow));
            return response.Value;
        }
        catch (Exception ex)
        {
            log.LogError(ex, "Exception while querying log analytics workspace");
            throw;
        }
    }
}
}